local parse_core = require "core"
local buildin_types = parse_core.buildin_types

local gmatch = string.gmatch
local tsort = table.sort
local tconcat = table.concat
local sformat = string.format

--table print start
local print = print
--local tconcat = table.concat
local tinsert = table.insert
local srep = string.rep
local type = type
local pairs = pairs
local tostring = tostring
local next = next
 
--- Debug helper: print a tree view of `root` to stdout.
-- Every entry renders as "+key [value]". Nested tables recurse with a
-- "|"-guided indent, and a table that was already visited renders as
-- "+key {path}", so shared references and cycles terminate. Defined as a
-- global (not local), exactly like the original.
function print_r(root)
  local visited = { [root] = "." }

  local function render(tbl, indent, path)
    local lines = {}
    for key, val in pairs(tbl) do
      local key_str = tostring(key)
      local label = "+" .. key_str
      local known = visited[val]
      if known then
        lines[#lines + 1] = label .. " {" .. known .. "}"
      elseif type(val) == "table" then
        local child_path = path .. "." .. key_str
        visited[val] = child_path
        -- draw a "|" guide only if more siblings follow this key
        local guide = next(tbl, key) and "|" or " "
        lines[#lines + 1] = label .. render(val, indent .. guide .. srep(" ", #key_str), child_path)
      else
        lines[#lines + 1] = label .. " [" .. tostring(val) .. "]"
      end
    end
    return tconcat(lines, "\n" .. indent)
  end

  print(render(root, "", ""))
end
--table print end
-- Metatable shared by every output stream. A stream is a plain array of
-- lines: write() appends one line and dump() joins them with "\n".
local mt = {}
mt.__index = mt

--- Return `s` with its first character upper-cased ("foo" -> "Foo", "" -> "").
local function upper_head(s)
  local first = string.sub(s, 1, 1)
  local rest = string.sub(s, 2)
  return string.upper(first) .. rest
end

--- Create a new, empty output stream.
local function create_stream()
  local stream = {}
  return setmetatable(stream, mt)
end

--- Append one line to the stream.
-- @param s    line text; nil writes an empty line
-- @param deep number of leading tab characters; nil means none
function mt:write(s, deep)
  local tabs = {}
  for n = 1, deep or 0 do
    tabs[n] = "\t"
  end
  self[#self + 1] = table.concat(tabs) .. (s or "")
end

--- Return all written lines joined by "\n" (no trailing newline added).
function mt:dump()
  return table.concat(self, "\n")
end



--- Split `str` into the maximal runs of characters NOT in `sep`.
-- `sep` is spliced into a Lua pattern character class ("[^sep]+"), so it is
-- a set of characters rather than a substring; pattern classes such as "%s"
-- are allowed, and "%s" (whitespace) is the default. Empty fields are dropped.
-- @return array of fields, in order
local function str_split(str, sep)
  local pattern = "([^" .. (sep == nil and "%s" or sep) .. "]+)"
  local fields = {}
  string.gsub(str, pattern, function(field)
    fields[#fields + 1] = field
  end)
  return fields
end

-- Banner written as the first line of every generated file.
local header = [[// Generated by sprotodump. DO NOT EDIT!]]

-- Preamble of the generated .h file. <unordered_map> is required because
-- map fields (array fields with a key) are emitted as std::unordered_map.
local headerIncludeCode = [[
#pragma once

#include <string>
#include <vector>
#include <unordered_map>
#include "sprotomessage.h"

]]

-- Preamble of the generated .cpp file. <fstream> is required by the emitted
-- LoadPbfile() helper, which reads the schema file through std::ifstream.
local cppIncludeCode = [[
#include <iostream>
#include <fstream>
#include <string.h>
#include "Protocol.h"
]]

--- Create one node of the class tree built by gen_type_class.
-- @param type_name   full dotted type name, or nil for a namespace-only node
-- @param class_name  the last segment of the dotted name (the C++ identifier)
-- @param sproto_type parsed sproto type (array of fields), or nil
-- @return node table; child nodes are appended to node.internal_class
local function type2class(type_name, class_name, sproto_type)
  local node = {}
  node.class_name = class_name
  node.type_name = type_name
  node.sproto_type = sproto_type
  node.internal_class = {}
  return node
end

--- Build the class tree for every type in `ast`.
-- Type names are dotted paths (e.g. "foo.request"). Each segment becomes a
-- node, and segments after the first are placed in their parent's
-- internal_class list. A segment that is not itself a type in `ast` (e.g.
-- "foo" when only "foo.request" exists) gets sproto_type = nil and is later
-- emitted as a C++ namespace by dump_class.
-- @param ast table mapping full type name -> sproto type (array of fields)
-- @return array of top-level nodes (see type2class), in sorted name order
local function gen_type_class(ast)
  local names = {}
  for name in pairs(ast) do
    names[#names + 1] = name
  end
  -- pairs() order is unspecified; sorting keeps the generated code stable
  tsort(names)

  local roots = {}
  local nodes = {} -- dotted path -> node, so shared prefixes are created once

  for _, full_name in ipairs(names) do
    local siblings = roots
    local path = ""
    for depth, segment in ipairs(str_split(full_name, ".")) do
      if depth == 1 then
        path = segment
      else
        path = path .. "." .. segment
      end

      local node = nodes[path]
      if not node then
        local sproto_type = ast[path]
        node = type2class(sproto_type and path or nil, segment, sproto_type)
        siblings[#siblings + 1] = node
        nodes[path] = node
      end

      siblings = node.internal_class
    end
  end

  return roots
end

--generate sprototype namespace
--- Name of the C++ namespace that holds all generated message classes,
-- e.g. "game" -> "GameSprotoType", "" -> "SprotoType".
local function _gen_sprototype_namespace(package)
  local head = upper_head(package)
  return head .. "SprotoType"
end

-- sproto builtin type name -> C++ member type.
local _class_type = {
  string = "std::string",
  integer = "int64_t",
  boolean = "bool",
}

--- Map a sproto field type to the C++ type of the generated member.
-- @param t        sproto type name (builtin or user struct name)
-- @param is_array true for array fields
-- @param key      key descriptor for map fields (array fields with a key);
--                 its typename must be a builtin type
-- @return C++ type string, e.g. "int64_t", "std::vector<Foo*>"
-- @raise on a map with a non-builtin key, or a key on a non-array field
local function _2class_type(t, is_array, key)
  t = _class_type[t] or t

  if is_array and key then -- map
    local tk = _class_type[key.typename]
    assert(tk, "Invalid map key.")
    -- No trailing ';': callers splice this into declarations like "%s %s_;".
    return string.format("std::unordered_map<%s, %s>", tk, t)
  elseif is_array then -- array: elements are stored as pointers
    return "std::vector<"..t.."*>"
  elseif not key then -- single element
    return t
  else
    error("Invalid field type.")
  end
end

-- Header declarations of the virtual field accessors written into a message
-- class, keyed by field kind: [1] = getter, [2] = setter. Each setter ends
-- with "\n" so a blank line follows every pair in the generated header.
local _function_Declare = {}
_function_Declare.string = {
  "virtual const char* GetStringField(int tagIndex, int index, int& len);",
  "virtual bool SetStringField(int tagIndex, int index, const char* value, int len);\n",
}
_function_Declare.integer = {
  "virtual int GetIntegerField(int tagIndex, int index, int64_t& value);",
  "virtual bool SetIntegerField(int tagIndex, int index, int64_t value);\n",
}
_function_Declare.boolean = {
  "virtual int GetBooleanField(int tagIndex, int index, bool& value);",
  "virtual bool SetBooleanField(int tagIndex, int index, bool value);\n",
}
_function_Declare.struct = {
  "virtual SprotoMessage* GetStructField(int tagIndex, int index);",
  "virtual SprotoMessage* SetStructField(int tagIndex, int index);\n",
}

-- Opening lines of the out-of-class accessor definitions in the .cpp, keyed
-- by field kind: [1] = getter, [2] = setter. The single %s is the owning
-- class name; each line ends with "{" and the body is appended afterwards.
local _function_Implemetion = {}
_function_Implemetion["string"] = {
  [1] = "const char* %s::GetStringField(int tagIndex, int index, int& len){",
  [2] = "bool %s::SetStringField(int tagIndex, int index, const char* value, int len){",
}
_function_Implemetion["integer"] = {
  [1] = "int %s::GetIntegerField(int tagIndex, int index, int64_t& value){",
  [2] = "bool %s::SetIntegerField(int tagIndex, int index, int64_t value){",
}
_function_Implemetion["boolean"] = {
  [1] = "int %s::GetBooleanField(int tagIndex, int index, bool& value){",
  [2] = "bool %s::SetBooleanField(int tagIndex, int index, bool value){",
}
_function_Implemetion["struct"] = {
  [1] = "SprotoMessage* %s::GetStructField(int tagIndex, int index){",
  [2] = "SprotoMessage* %s::SetStructField(int tagIndex, int index){",
}

-- C++ body fragments for the accessor definitions, keyed by field kind.
-- Each kind holds four fragments:
--   [1] getter branch for one field, formatted with the field tag (%d) and
--       then the field name for every %s (struct: tag, name, name)
--   [2] getter fallback, closing the if/else-if chain
--   [3] setter branch for one field (struct: tag, struct type, struct type, name)
--   [4] setter fallback
-- Branch fragments end in "else", so consecutive fields chain into
-- `if (...) {...} else if (...) {...} else {fallback}`.
-- NOTE(review): the string setter does not set b<name>_NeedEncode_, unlike
-- the integer/boolean setters; confirm this asymmetry is intended.
local _function_subImp = 
{
  -- string getter hands out a pointer into the std::string (NULL when empty)
  string = {[[
    if (tagIndex == %d)
    {
        if (%s_.empty())
        {
            len = 0;
            return NULL;
        }
        else
        {
            len = %s_.length();
            return &%s_[0];
        }
    }else]],[[
    {
        return NULL;
    }]], [[
    if (tagIndex == %d)
    {
        %s_.assign(value, len);
        return true;
    }else]],[[
    {
        return false;
    }]]},
    -- integer getter returns 1 and fills `value` only when the field was set
    -- (b<name>_NeedEncode_); both index branches are currently identical
    integer = {[[
    if (tagIndex == %d)
    {
        if(index == -1)
        {
            if(b%s_NeedEncode_)
            {
                value = %s_;
                return 1;
            }
            return 0;
        }
        else
        {
            if(b%s_NeedEncode_)
            {
                value = %s_;
                return 1;
            }
            return 0;
        }
        
    }else]],[[
    {
        return 0;
    }]], [[
    if (tagIndex == %d)
    {
        %s_ = value;
        b%s_NeedEncode_ = true;
        return true;
    }else]],[[
    {
        return false;
    }]]},
    -- boolean: same shape as integer
    boolean = {[[
    if (tagIndex == %d)
    {
        if(index == -1)
        {
            if(b%s_NeedEncode_)
            {
                value = %s_;
                return 1;
            }
            return 0;
        }
        else
        {
             if(b%s_NeedEncode_)
            {
                value = %s_;
                return 1;
            }
            return 0;
        }
    }else]],[[
    {
        return 0;
    }]], [[
    if (tagIndex == %d)
    {
        %s_ = value;
        b%s_NeedEncode_ = true;
        return true;
    }else]],[[
    {
        return false;
    }]]},
    -- struct arrays: the getter indexes the vector; the setter allocates a new
    -- element, appends it, and returns it so the caller can fill it in.
    -- dump_class writes the fallbacks twice: the second copy also covers a
    -- matched tag whose index was out of range.
    struct = {
    [[
    if (tagIndex == %d) 
    {
        if (index >= 0 && index < (int)%s_.size())
            return %s_[index];
    }else]],
    [[
    {
        return NULL;
    }]],
    [[
    if (tagIndex == %d)
    {
        %s* msg = new %s;
        %s_.push_back(msg);
        return msg;
    }else]],
    [[
    {
        return NULL;
    }
    ]]
  }    
}

-- GetMessageName() definition, formatted with (class name, message name).
-- The name is kept in a function-local static std::string.
local _getMessage_Imp = [[
std::string %s::GetMessageName(){
    static std::string NameStr = "%s";
    return NameStr;
}
]]

-- Destructor fragment for one struct-array member, formatted with the member
-- name twice; deletes every element pointer held in the "<name>_" vector.
local _cpp_DestructorStr = [[
    for (auto it = %s_.begin();
      it != %s_.end(); ++it)
    {
        auto pVar = *it;
        delete pVar;
    }
]]

-- Fixed emission order for the per-kind accessor groups. Iterating the
-- grouping table with pairs() would make the generated files differ from
-- run to run, because pairs() order is unspecified.
local _field_kind_order = { "string", "integer", "boolean", "struct" }

--- Emit the C++ declaration (header) and definition (cpp) of one class-tree
-- node produced by gen_type_class, recursing into its internal classes.
-- A node with a sproto_type becomes `class X : public SprotoMessage`; a node
-- without one (e.g. the "foo" of "foo.request") becomes a C++ namespace.
-- @param class_info     node {class_name, sproto_type, internal_class}
-- @param streamHeader   stream receiving the .h text
-- @param headerDeep     header indentation depth
-- @param streamCpp      stream receiving the .cpp text
-- @param cppDeep        cpp indentation depth
-- @param namespaceOuter optional enclosing namespace name; when set,
--        GetMessageName() returns "outer.name" (e.g. "foo.request") instead
--        of just "request"
-- @raise if a non-array field has a type with no accessor group
local function dump_class(class_info, streamHeader, headerDeep, streamCpp, cppDeep, namespaceOuter)
  local class_name = class_info.class_name
  local sproto_type = class_info.sproto_type
  local internal_class = class_info.internal_class

  if not sproto_type then
    -- namespace-only node: just wrap the children
    streamHeader:write("namespace "..class_name.." {", headerDeep)
    streamCpp:write("namespace "..class_name.." {", cppDeep)
    streamHeader:write("", headerDeep)
    for i = 1, #internal_class do
      dump_class(internal_class[i], streamHeader, headerDeep + 1, streamCpp, 1, class_name)
    end
    streamHeader:write("};\n\n", headerDeep)
    streamCpp:write("}\n\n", cppDeep)
    return
  end

  streamHeader:write("class "..class_name.." : public  SprotoMessage {", headerDeep)
  headerDeep = headerDeep + 1

  -- NOTE(review): internal classes are emitted before the first "public:",
  -- i.e. in the class's private section, and their .cpp definitions are not
  -- qualified with the outer class name. Confirm before relying on nesting.
  streamHeader:write("", headerDeep)
  for i = 1, #internal_class do
    dump_class(internal_class[i], streamHeader, headerDeep, streamCpp, cppDeep)
  end

  streamHeader:write("public:", headerDeep-1)
  streamHeader:write(class_name.."();", headerDeep)
  streamHeader:write("virtual ~"..class_name.."();", headerDeep)
  streamHeader:write("virtual std::string GetMessageName();", headerDeep)

  -- Fields grouped by accessor kind. Entries are {tag, name} for builtin
  -- fields and {tag, name, element typename} for array ("struct") fields.
  local groups = { string = {}, integer = {}, boolean = {}, struct = {} }

  streamHeader:write("", headerDeep)
  for i = 1, #sproto_type do
    local field = sproto_type[i]
    local ctype = _2class_type(field.typename, field.array, field.key)
    local name = field.name
    local tag = field.tag

    if field.array then
      local list = groups.struct
      list[#list + 1] = { tag, name, field.typename }
    else
      local list = groups[field.typename]
      if not list then
        error(sformat("unsupported non-array field type '%s' for field '%s' in %s",
          tostring(field.typename), tostring(name), class_name))
      end
      list[#list + 1] = { tag, name }
    end

    -- typed setter/getter and storage for this field
    streamHeader:write("public:", headerDeep-1)
    if field.array then
      streamHeader:write(sformat("void Set%s(const %s& %s) {%s_ = %s;}", upper_head(name), ctype, name, name, name), headerDeep)
      streamHeader:write(sformat("void Add%s(%s* %s){ %s_.push_back(%s);}\n", field.typename, field.typename, name, name, name), headerDeep)
    else
      streamHeader:write(sformat("void Set%s(const %s& %s) {%s_ = %s; b%s_NeedEncode_ = true;}", upper_head(name), ctype, name, name, name, name), headerDeep)
      streamHeader:write(sformat("bool Is%s_NeedEncode(){ return b%s_NeedEncode_;}\n", name, name), headerDeep)
    end
    streamHeader:write(sformat("%s& Get%s() {return %s_;}", ctype, upper_head(name), name), headerDeep)
    streamHeader:write("private:", headerDeep-1)
    streamHeader:write(sformat("%s %s_;\n", ctype, name), headerDeep)
    if not field.array then
      streamHeader:write(sformat("bool b%s_NeedEncode_;\n", name), headerDeep)
    end
  end

  -- generic tag-indexed accessors, one getter/setter pair per kind in use
  streamHeader:write("public:", headerDeep-1)
  for _, kind in ipairs(_field_kind_order) do
    local fields = groups[kind]
    if #fields > 0 then
      local decl = _function_Declare[kind]
      local impl = _function_Implemetion[kind]
      local sub = _function_subImp[kind]
      local is_struct = (kind == "struct")

      streamHeader:write(decl[1], headerDeep)
      streamHeader:write(decl[2], headerDeep)

      -- getter definition
      streamCpp:write(sformat(impl[1], class_name))
      streamCpp:write("", cppDeep + 1)
      for _, f in ipairs(fields) do
        local tag, name = f[1], f[2]
        if is_struct then
          streamCpp:write(sformat(sub[1], tag, name, name))
        else
          streamCpp:write(sformat(sub[1], tag, name, name, name, name))
        end
      end
      streamCpp:write(sub[2])
      if is_struct then
        -- reached when a tag matched but the index was out of range
        streamCpp:write(sub[2])
      end
      streamCpp:write("}")

      -- setter definition
      streamCpp:write(sformat(impl[2], class_name))
      streamCpp:write("", cppDeep + 1)
      for _, f in ipairs(fields) do
        local tag, name = f[1], f[2]
        if is_struct then
          local elem_type = f[3]
          streamCpp:write(sformat(sub[3], tag, elem_type, elem_type, name))
        else
          streamCpp:write(sformat(sub[3], tag, name, name))
        end
      end
      streamCpp:write(sub[4])
      if is_struct then
        streamCpp:write(sub[4])
      end
      streamCpp:write("}")
    end
  end

  -- GetMessageName()
  streamCpp:write("", cppDeep)
  local message_name = class_name
  if namespaceOuter then
    message_name = namespaceOuter.."."..class_name
  end
  streamCpp:write(sformat(_getMessage_Imp, class_name, message_name))

  -- constructor: initializers follow field (= member declaration) order so
  -- the compiler does not warn about reordering
  streamCpp:write("", cppDeep)
  streamCpp:write(sformat("%s::%s()", class_name, class_name))
  local inits = {}
  for i = 1, #sproto_type do
    local field = sproto_type[i]
    if not field.array then
      local name = field.name
      if field.typename == "integer" then
        inits[#inits + 1] = name.."_(0)"
      elseif field.typename == "boolean" then
        inits[#inits + 1] = name.."_(false)"
      end
      inits[#inits + 1] = "b"..name.."_NeedEncode_(false)"
    end
  end
  local initStr = ""
  if #inits > 0 then
    initStr = ":"..tconcat(inits, ", ")
  end
  streamCpp:write(initStr.."\n{}", cppDeep)

  -- destructor: free the elements owned by every struct array
  streamCpp:write("", cppDeep)
  streamCpp:write(sformat("%s::~%s(){", class_name, class_name))
  if #groups.struct > 0 then
    local parts = {}
    for _, f in ipairs(groups.struct) do
      parts[#parts + 1] = sformat(_cpp_DestructorStr, f[2], f[2])
    end
    streamCpp:write(tconcat(parts), cppDeep)
  end
  streamCpp:write("}")

  headerDeep = headerDeep - 1
  streamHeader:write("};\n\n", headerDeep)
end

--- Emit the whole "<Package>SprotoType" namespace into both streams.
-- @param class         array of top-level nodes from gen_type_class
-- @param streamHeader  stream receiving the .h text
-- @param streamCpp     stream receiving the .cpp text
-- @param package       package name used for the namespace
local function parse_type(class, streamHeader, streamCpp, package)
  if not class or #class == 0 then return end

  local namespace = _gen_sprototype_namespace(package)
  streamHeader:write("namespace "..namespace.." { ")
  streamCpp:write("namespace "..namespace.." { ")

  -- Forward-declare the top-level message classes so they can refer to each
  -- other. Namespace-only nodes (no sproto_type) must be skipped: they are
  -- opened as `namespace X {` below, and C++ rejects redeclaring a class
  -- name as a namespace.
  for _, class_info in ipairs(class) do
    if class_info.sproto_type then
      streamHeader:write("class "..class_info.class_name..";")
    end
  end

  for _, class_info in ipairs(class) do
    dump_class(class_info, streamHeader, 1, streamCpp, 1)
  end

  streamHeader:write("};\n\n")
  streamCpp:write("}\n\n")
end

local util = require "util"
-- C++ helper written verbatim into every generated .cpp. It reads a whole
-- binary file into `pb` and returns false if the file cannot be opened, its
-- size cannot be determined, or the read fails. The signature must stay in
-- sync with the declaration written into the generated header.
local LoadPbFileStr = [[
bool LoadPbfile(std::string& filename, std::string& pb)
{
  std::ifstream ifs(filename, std::ifstream::binary);
  if (!ifs.is_open())
    return false;

  ifs.seekg(0, ifs.end);
  std::streamoff length = ifs.tellg();
  ifs.seekg(0, ifs.beg);
  if (length < 0)
    return false;

  pb.assign(static_cast<std::string::size_type>(length), ' ');
  if (length > 0 && !ifs.read(&pb[0], length))
    return false;

  return true;
}
]]
--generate all class code
--- Generate the .h and .cpp text for all types of one sproto file.
-- @param ast     table mapping dotted type name -> sproto type
-- @param package package name (nil means "")
-- @param name    source file path; used in the banner and to derive the
--                "<basename>.h" include (defaults to "input")
-- @return header text, cpp text
local function parse_ast2type(ast, package, name)
  package = package or ""
  local source = name or "input"
  local type_class = gen_type_class(ast)

  local streamHeader = create_stream()
  local streamCpp = create_stream()

  streamHeader:write(header)
  streamHeader:write("// source: "..source.."\n")
  streamHeader:write(headerIncludeCode)
  streamHeader:write("bool LoadPbfile(std::string& filename, std::string& pb);")

  streamCpp:write(header)
  streamCpp:write("// source: "..source.."\n")
  streamCpp:write(cppIncludeCode)
  streamCpp:write(sformat("#include \"%s.h\"", util.path_basename(source)))
  streamCpp:write(LoadPbFileStr)

  parse_type(type_class, streamHeader, streamCpp, package)

  return streamHeader:dump(), streamCpp:dump()
end

--generate protocol code
--- Collect protocol descriptions from `ast` and build a name tree.
-- @param ast table mapping dotted protocol name -> {tag, request, response}
-- @return array of {name, tag, request, response} sorted by name; its
--   `classes` field holds the tree of nodes {name = segment, value = ast
--   entry or nil}, nested by dotted-name segment
local function gen_protocol_class(ast)
  local list = {}
  for proto_name, proto in pairs(ast) do
    list[#list + 1] = {
      name = proto_name,
      tag = proto.tag,
      request = proto.request,
      response = proto.response,
    }
  end
  tsort(list, function(x, y) return x.name < y.name end)

  local roots = {}
  local seen = {} -- dotted path -> node
  for idx = 1, #list do
    local parent = roots
    local path
    local segments = str_split(list[idx].name, ".")
    for depth = 1, #segments do
      local seg = segments[depth]
      path = (depth == 1) and seg or (path .. "." .. seg)

      local node = seen[path]
      if node == nil then
        node = {}
        seen[path] = node
        parent[#parent + 1] = node
      end

      node.name = seg
      node.value = ast[path]
      parent = node
    end
  end

  list.classes = roots
  return list
end

--- Name of the generated protocol class, e.g. "game" -> "GameProtocol".
local function _gen_protocol_classname(package)
  local head = upper_head(package)
  return head .. "Protocol"
end

--- Emit the body of the generated DecodeData(): one `if(typeTag == N){...}`
-- block per protocol. Each block allocates the request message, or the
-- response message when bGenResponse is true and a response exists, then
-- decodes `buf` into it.
-- @param class   array from gen_protocol_class ({name, tag, request, response})
-- @param package package name used for the "<Package>SprotoType" namespace
-- @param stream  output stream
-- @param deep    indentation depth of the enclosing function body
local function constructor_protocol(class, package, stream, deep)
  local type_namespace = _gen_sprototype_namespace(package)

  -- "foo.request" -> "<Ns>::foo::request"
  local function qualify(type_name)
    return (string.gsub(type_namespace.."."..type_name, "%.", "::"))
  end

  deep = deep + 1
  for _, info in ipairs(class) do
    local request_type = info.request and qualify(info.request)
    local response_type = info.response and qualify(info.response)

    stream:write(sformat("if(typeTag == %d){", info.tag), deep)
    if request_type and response_type then
      stream:write("if(!bGenResponse){", deep+1)
      stream:write(sformat("spData = new %s();", request_type), deep+2)
      stream:write("}else{", deep+1)
      stream:write(sformat("spData = new %s();", response_type), deep+2)
      stream:write("}", deep+1)
    elseif request_type then
      stream:write(sformat("spData = new %s();", request_type), deep+1)
    elseif response_type then
      stream:write("if(bGenResponse){", deep+1)
      stream:write(sformat("spData = new %s();", response_type), deep+2)
      stream:write("}", deep+1)
    end

    if request_type or response_type then
      -- spData stays NULL for a response-only protocol when !bGenResponse
      stream:write("if(spData != NULL){", deep+1)
      stream:write("pSproto->Decode(spData, buf, size);", deep+2)
      stream:write("}", deep+1)
    end
    stream:write("}", deep)
  end
end

--- Emit an empty C++ class per protocol-name segment, nested by segment.
-- Leaf nodes (those with a protocol `value`) also get a Type2Int<>
-- specialization, written to `streamProtoClassToTag`, that maps the class to
-- its protocol tag.
-- @param class                 node from gen_protocol_class(...).classes
-- @param stream                stream receiving the class skeletons
-- @param deep                  indentation depth
-- @param streamProtoClassToTag stream receiving the Type2Int specializations
local function dump_protocol_class(class, stream, deep, streamProtoClassToTag)
  local name = class.name
  local value = class.value

  stream:write("class "..name.." {", deep)
  if value then
    -- a protocol leaf has no nested children
    assert(#class == 0)
    streamProtoClassToTag:write("template<>", deep)
    -- `static` so the tag is readable as Type2Int<X>::Tag without an instance
    streamProtoClassToTag:write(string.format("class Type2Int<class %s>{public: static const int Tag = %d;};", name, value.tag), deep)
  else
    for _, child in ipairs(class) do
      dump_protocol_class(child, stream, deep+1, streamProtoClassToTag)
    end
  end
  stream:write("};\n", deep)
end

--- Emit `namespace Sproto { class <Package>Protocol {...} }`, which holds the
-- static CppSproto instance, Init(), and DecodeData(), followed by the
-- Type2Int<> tag-lookup templates for every protocol.
-- @param class   result of gen_protocol_class (array plus .classes tree)
-- @param stream  output stream
-- @param package package name
local function parse_protocol(class, stream, package)
  if not class or #class == 0 then return end

  local class_name = _gen_protocol_classname(package)
  stream:write("namespace Sproto{\n\tclass "..class_name.." {")

  stream:write("public:", 1)
  stream:write("static  CppSproto* pSproto;", 2)
  -- Init() returns a value on every path, and pSproto is only published after
  -- a successful load, so a failed Init() can be retried.
  local InitFunc = [[static  bool Init() {
      if (pSproto != NULL) {
        return true;
      }
      std::string pb;
      std::string fileName = cocos2d::FileUtils::getInstance()->fullPathForFilename("sproto.spb");
      if (!LoadPbfile(fileName, pb)) {
        return false;
      }
      CppSproto* sp = new CppSproto;
      if (!sp->Init(pb.data(), pb.length())) {
        delete sp;
        return false;
      }
      pSproto = sp;
      return true;
    }]]
  stream:write(InitFunc, 2)
  stream:write("static  SprotoMessage*  DecodeData(int typeTag,const char* buf,int size, bool bGenResponse = false){", 2)
  stream:write("SprotoMessage* spData = NULL;", 3)

  constructor_protocol(class, package, stream, 2)
  stream:write("return spData;", 3)
  stream:write("}", 2)
  stream:write("};", 1)

  -- Type2Int<T>::Tag maps a protocol class to its tag (-1 for unknown types).
  local streamProtoClassToTag = create_stream()
  streamProtoClassToTag:write("template<class T>", 2)
  streamProtoClassToTag:write("class Type2Int{public: static const int Tag = -1;};", 2)
  for _, v in ipairs(class.classes) do
    dump_protocol_class(v, stream, 2, streamProtoClassToTag)
  end

  stream:write(streamProtoClassToTag:dump())
  stream:write("}")
end

--- Generate the text of "<package>Protocol.h".
-- NOTE(review): the parse_protocol() call is commented out, so the output is
-- only the banner and #include lines; protocol_class is built but unused.
-- @param ast     protocol AST (dotted name -> {tag, request, response})
-- @param package package name (nil means "")
-- @param name    source file path, used for the banner and the types include
-- @return header text
local function parse_ast2protocol(ast, package, name)
  package = package or ""
  local protocol_class = gen_protocol_class(ast)
  local out = create_stream()

  local preamble = {
    header,
    "// source: "..(name or "input").."\n",
    "#pragma once\n#include \"cppsproto.h\"\n#include <iostream>\n#include <fstream>",
    string.format("#include \"%s.h\"", util.path_basename(name)),
  }
  for _, line in ipairs(preamble) do
    out:write(line)
  end

  --parse_protocol(protocol_class, out, package)

  return out:dump()
end

--[===[
local function parse_ast2all(ast, package, name)
  package = package or ""
  local type_class = gen_type_class(ast.type)
  local protocol_class = gen_protocol_class(ast.protocol)
  local stream = create_stream()

  stream:write(header)
  stream:write([[// source: ]]..(name or "input").."\n")
  --stream:write(using)

  parse_type(type_class, stream, package)
  parse_protocol(protocol_class, stream, package)

  return stream:dump()  
end
--]===]
------------------------------- dump -------------------------------------


--- Backend entry point: write the generated C++ files.
-- For each input file i, writes "<basename>.h" and "<basename>.cpp" from
-- trunk[i].type. If build.protocol is present, also writes
-- "<package>Protocol.h".
-- @param trunk array of per-file parse results (trunk[i].type = type AST)
-- @param build merged parse result (build.protocol = protocol AST)
-- @param param options: package, outfile, dircetory (output dir prefix),
--              sproto_file (array of input paths)
-- @raise when param.outfile is set (single-file output is not supported)
local function main(trunk, build, param)
  local package = util.path_basename(param.package or "") -- all packages
  local dir = param.dircetory or ""

  if param.outfile then
    -- Single-file output relied on parse_ast2all, which is commented out in
    -- this backend; fail clearly instead of "attempt to call a nil value".
    error("cpp backend: single output file (outfile) is not supported; omit it to generate per-file .h/.cpp")
  end

  -- dump sprototype
  for i, v in ipairs(trunk) do
    local name = param.sproto_file[i]
    local fileName = util.path_basename(name)
    local headerdata, cppdata = parse_ast2type(v.type, package, name)
    util.write_file(dir..fileName..".h", headerdata, "w")
    util.write_file(dir..fileName..".cpp", cppdata, "w")
  end

  -- dump protocol
  if build.protocol then
    local name = param.sproto_file[1]
    local data = parse_ast2protocol(build.protocol, package, name)
    util.write_file(dir..package.."Protocol.h", data, "w")
  end
end

return main
